// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <unordered_map>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/formats/object_detection/anchor.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/tflite/config.h"
#include "tensorflow/lite/interpreter.h"

#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE

#if MEDIAPIPE_TFLITE_METAL_INFERENCE
#import <CoreVideo/CoreVideo.h>
#import <Metal/Metal.h>
#import <MetalKit/MetalKit.h>

#import "mediapipe/gpu/MPPMetalHelper.h"
#include "mediapipe/gpu/MPPMetalUtil.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE

// File-local constants shared by the CPU and GPU code paths below.
namespace
{
  // Number of input tensors when the model supplies anchors as a third tensor
  // (boxes, scores, anchors).
  constexpr int kNumInputTensorsWithAnchors = 3;
  // Values stored per anchor: y_center, x_center, h, w.
  constexpr int kNumCoordsPerBox = 4;

  // Input stream tag carrying CPU tensors.
  constexpr char kTensorsTag[] = "TENSORS";
  // Input stream tag carrying GPU tensors.
  constexpr char kTensorsGpuTag[] = "TENSORS_GPU";
} // namespace

namespace mediapipe
{

#if MEDIAPIPE_TFLITE_GL_INFERENCE
  // OpenGL ES backend: pull in the TFLite GL helpers used by this file and
  // name the compute-program type.
  using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
  using ::tflite::gpu::gl::GlShader;
  using GpuProgram = ::tflite::gpu::gl::GlProgram;
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
  // Metal backend: a compute pipeline state plays the role of the program.
  using GpuProgram = id<MTLComputePipelineState>;
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE

  namespace
  {

#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
    // GPU-side state used by ProcessGPU: the two compute programs and the
    // buffers they read from and write to.
    struct GPUData
    {
      // Program dispatched to decode raw boxes against anchors.
      GpuProgram decode_program;
      // Program dispatched to score boxes.
      GpuProgram score_program;
      // Output of decode_program; read back to the CPU after dispatch.
      GpuTensor decoded_boxes_buffer;
      // Copy of input tensor 0 (raw boxes).
      GpuTensor raw_boxes_buffer;
      // Anchor values, filled once from the side packet or input tensor 2.
      GpuTensor raw_anchors_buffer;
      // Output of score_program; read back as (score, class id) float pairs.
      GpuTensor scored_boxes_buffer;
      // Copy of input tensor 1 (raw scores).
      GpuTensor raw_scores_buffer;
    };
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED

    void ConvertRawValuesToAnchors(const float *raw_anchors, int num_boxes,
                                   std::vector<Anchor> *anchors)
    {
      anchors->clear();
      for (int i = 0; i < num_boxes; ++i)
      {
        Anchor new_anchor;
        new_anchor.set_y_center(raw_anchors[i * kNumCoordsPerBox + 0]);
        new_anchor.set_x_center(raw_anchors[i * kNumCoordsPerBox + 1]);
        new_anchor.set_h(raw_anchors[i * kNumCoordsPerBox + 2]);
        new_anchor.set_w(raw_anchors[i * kNumCoordsPerBox + 3]);
        anchors->push_back(new_anchor);
      }
    }

    void ConvertAnchorsToRawValues(const std::vector<Anchor> &anchors,
                                   int num_boxes, float *raw_anchors)
    {
      ABSL_CHECK_EQ(anchors.size(), num_boxes);
      int box = 0;
      for (const auto &anchor : anchors)
      {
        raw_anchors[box * kNumCoordsPerBox + 0] = anchor.y_center();
        raw_anchors[box * kNumCoordsPerBox + 1] = anchor.x_center();
        raw_anchors[box * kNumCoordsPerBox + 2] = anchor.h();
        raw_anchors[box * kNumCoordsPerBox + 3] = anchor.w();
        ++box;
      }
    }

  } // namespace

  // Converts result TFLite tensors from object detection models into MediaPipe
  // Detections.
  //
  // Input:
  //  TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. The vector of
  //               tensors can have 2 or 3 tensors. The first tensor is the
  //               predicted raw boxes/keypoints. The size of the values must be
  //               (num_boxes * num_predicted_values). The second tensor is the
  //               score tensor. The size of the values must be
  //               (num_boxes * num_classes). Depending on the outputs of the
  //               detection model, a third tensor with anchors may optionally
  //               be passed in (e.g. for SSD models). The size of the anchor
  //               tensor must be (num_boxes * 4).
  //               On CPU, a vector of 4 tensors (boxes, classes, scores,
  //               number of boxes) is also accepted for models that already
  //               contain the postprocessing op.
  //  TENSORS_GPU - Vector of GlBuffer or MTLBuffer.
  // Input side packet (optional):
  //  ANCHORS - std::vector<Anchor>, used when the anchors are not provided as
  //               an input tensor.
  // Output:
  //  DETECTIONS - Result MediaPipe detections.
  //
  // Usage example:
  // node {
  //   calculator: "TfLiteTensorsToDetectionsCalculator"
  //   input_stream: "TENSORS:tensors"
  //   input_side_packet: "ANCHORS:anchors"
  //   output_stream: "DETECTIONS:detections"
  //   options: {
  //     [mediapipe.TfLiteTensorsToDetectionsCalculatorOptions.ext] {
  //       num_classes: 91
  //       num_boxes: 1917
  //       num_coords: 4
  //       ignore_classes: [0, 1, 2]
  //       x_scale: 10.0
  //       y_scale: 10.0
  //       h_scale: 5.0
  //       w_scale: 5.0
  //     }
  //   }
  // }
  class TfLiteTensorsToDetectionsCalculator : public CalculatorBase
  {
  public:
    // Declares the input/output/side-packet types and, for GPU input, the GPU
    // helper requirements.
    static absl::Status GetContract(CalculatorContract *cc);

    // Loads options and, for GPU input, initializes the GPU helper and programs.
    absl::Status Open(CalculatorContext *cc) override;
    // Converts the tensors of the current packet into detections.
    absl::Status Process(CalculatorContext *cc) override;
    // Releases GPU data.
    absl::Status Close(CalculatorContext *cc) override;

  private:
    // Handles the TENSORS input.
    absl::Status ProcessCPU(CalculatorContext *cc,
                            std::vector<Detection> *output_detections);
    // Handles the TENSORS_GPU input.
    absl::Status ProcessGPU(CalculatorContext *cc,
                            std::vector<Detection> *output_detections);

    // Copies the graph options into the members below.
    absl::Status LoadOptions(CalculatorContext *cc);
    absl::Status GpuInit(CalculatorContext *cc);
    // Decodes raw box/keypoint values against `anchors` into `boxes`.
    absl::Status DecodeBoxes(const float *raw_boxes,
                             const std::vector<Anchor> &anchors,
                             std::vector<float> *boxes);
    // Appends one Detection per box that passes the score and size filters.
    absl::Status ConvertToDetections(const float *detection_boxes,
                                     const float *detection_scores,
                                     const int *detection_classes,
                                     std::vector<Detection> *output_detections);
    // Builds a single Detection with a relative bounding box.
    Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
                                 float box_xmax, float score, int class_id,
                                 bool flip_vertically);

    // Values taken from the options in LoadOptions. Note that num_boxes_ is
    // overwritten by ProcessCPU when the model reports its own box count
    // (4-tensor input).
    int num_classes_ = 0;
    int num_boxes_ = 0;
    int num_coords_ = 0;
    // Class indices that are skipped when picking the best class on CPU.
    std::set<int> ignore_classes_;

    ::mediapipe::TfLiteTensorsToDetectionsCalculatorOptions options_;
    // Anchors cached on the CPU path after the first packet.
    std::vector<Anchor> anchors_;
    // True when the ANCHORS input side packet is connected.
    bool side_packet_anchors_{};

#if MEDIAPIPE_TFLITE_GL_INFERENCE
    mediapipe::GlCalculatorHelper gpu_helper_;
    std::unique_ptr<GPUData> gpu_data_;
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
    MPPMetalHelper *gpu_helper_ = nullptr;
    std::unique_ptr<GPUData> gpu_data_;
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE

    // True when the TENSORS_GPU input is connected.
    bool gpu_input_ = false;
    // Set once anchors have been loaded, so they are only loaded once.
    bool anchors_init_ = false;
  };
  REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);

  // Declares the packet types of every stream / side packet this calculator
  // understands. When the TENSORS_GPU input is connected, the GPU helper of the
  // compiled-in backend also gets a chance to update the contract.
  absl::Status TfLiteTensorsToDetectionsCalculator::GetContract(
      CalculatorContract *cc)
  {
    RET_CHECK(!cc->Inputs().GetTags().empty());
    RET_CHECK(!cc->Outputs().GetTags().empty());

    // Input streams.
    auto &inputs = cc->Inputs();
    if (inputs.HasTag(kTensorsTag))
    {
      inputs.Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
    }
    const bool use_gpu = inputs.HasTag(kTensorsGpuTag);
    if (use_gpu)
    {
      inputs.Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
    }

    // Output stream.
    auto &outputs = cc->Outputs();
    if (outputs.HasTag("DETECTIONS"))
    {
      outputs.Tag("DETECTIONS").Set<std::vector<Detection>>();
    }

    // Optional anchors side packet.
    auto &side_packets = cc->InputSidePackets();
    if (side_packets.UsesTags() && side_packets.HasTag("ANCHORS"))
    {
      side_packets.Tag("ANCHORS").Set<std::vector<Anchor>>();
    }

    if (use_gpu)
    {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
      MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
      MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
    }

    return absl::OkStatus();
  }

  // Prepares the calculator for a run: sets the timestamp offset, opens the GPU
  // helper when TENSORS_GPU is connected, loads the options, records whether
  // the ANCHORS side packet is connected and, for GPU input, runs GpuInit.
  absl::Status TfLiteTensorsToDetectionsCalculator::Open(CalculatorContext *cc)
  {
    cc->SetOffset(TimestampDiff(0));

    const bool has_gpu_tensors = cc->Inputs().HasTag(kTensorsGpuTag);
    if (has_gpu_tensors)
    {
      gpu_input_ = true;
#if MEDIAPIPE_TFLITE_GL_INFERENCE
      MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
      gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
      RET_CHECK(gpu_helper_);
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
    }

    MP_RETURN_IF_ERROR(LoadOptions(cc));
    side_packet_anchors_ = cc->InputSidePackets().HasTag("ANCHORS");

    // Nothing more to set up for CPU input.
    if (!gpu_input_)
    {
      return absl::OkStatus();
    }

    MP_RETURN_IF_ERROR(GpuInit(cc));
    return absl::OkStatus();
  }

  // Converts the tensors of the current packet into a vector of detections and
  // emits it on DETECTIONS at the input timestamp. A packet whose active tensor
  // stream is empty produces no output.
  absl::Status TfLiteTensorsToDetectionsCalculator::Process(CalculatorContext *cc)
  {
    // Only the stream matching the active mode is inspected.
    const char *active_tag = gpu_input_ ? kTensorsGpuTag : kTensorsTag;
    if (cc->Inputs().Tag(active_tag).IsEmpty())
    {
      return absl::OkStatus();
    }

    auto detections = absl::make_unique<std::vector<Detection>>();
    MP_RETURN_IF_ERROR(gpu_input_ ? ProcessGPU(cc, detections.get())
                                  : ProcessCPU(cc, detections.get()));

    if (!cc->Outputs().HasTag("DETECTIONS"))
    {
      return absl::OkStatus();
    }

    // The output stream takes ownership of the released vector.
    cc->Outputs()
        .Tag("DETECTIONS")
        .Add(detections.release(), cc->InputTimestamp());
    return absl::OkStatus();
  }

  // Converts CPU tensors from the TENSORS stream into detections.
  //
  // Two input layouts are supported:
  //  * 2 or 3 tensors (raw boxes, raw scores, optional anchors): boxes are
  //    decoded against the anchors and the best non-ignored class is picked per
  //    box here.
  //  * 4 tensors (boxes, classes, scores, number of boxes): the model already
  //    performed decoding and non-maximum suppression; values are forwarded.
  //
  // Args:
  //   cc: calculator context providing the input tensors and side packets.
  //   output_detections: receives the resulting detections (appended).
  // Returns:
  //   OK on success; an error status when tensor shapes do not match the
  //   configured options, when no anchors are available, or when the model
  //   reports more boxes than its output tensors hold.
  absl::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU(
      CalculatorContext *cc, std::vector<Detection> *output_detections)
  {
    const auto &input_tensors =
        cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();

    if (input_tensors.size() == 2 ||
        input_tensors.size() == kNumInputTensorsWithAnchors)
    {
      // Postprocessing on CPU for model without postprocessing op. E.g. output
      // raw score tensor and box tensor. Anchor decoding will be handled below.
      const TfLiteTensor *raw_box_tensor = &input_tensors[0];
      const TfLiteTensor *raw_score_tensor = &input_tensors[1];

      // TODO: Add flexible input tensor size handling.
      // Shape mismatches are reported as an error status (like the 4-tensor
      // branch below) instead of aborting the process.
      RET_CHECK_EQ(raw_box_tensor->dims->size, 3);
      RET_CHECK_EQ(raw_box_tensor->dims->data[0], 1);
      RET_CHECK_EQ(raw_box_tensor->dims->data[1], num_boxes_);
      RET_CHECK_EQ(raw_box_tensor->dims->data[2], num_coords_);
      RET_CHECK_EQ(raw_score_tensor->dims->size, 3);
      RET_CHECK_EQ(raw_score_tensor->dims->data[0], 1);
      RET_CHECK_EQ(raw_score_tensor->dims->data[1], num_boxes_);
      RET_CHECK_EQ(raw_score_tensor->dims->data[2], num_classes_);
      const float *raw_boxes = raw_box_tensor->data.f;
      // Scores are read in place; no per-packet copy of the tensor is needed.
      const float *raw_scores = raw_score_tensor->data.f;

      // TODO: Support other options to load anchors.
      // Anchors are loaded once, on the first packet.
      if (!anchors_init_)
      {
        if (input_tensors.size() == kNumInputTensorsWithAnchors)
        {
          const TfLiteTensor *anchor_tensor = &input_tensors[2];
          RET_CHECK_EQ(anchor_tensor->dims->size, 2);
          RET_CHECK_EQ(anchor_tensor->dims->data[0], num_boxes_);
          RET_CHECK_EQ(anchor_tensor->dims->data[1], kNumCoordsPerBox);
          const float *raw_anchors = anchor_tensor->data.f;
          ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_);
        }
        else if (side_packet_anchors_)
        {
          RET_CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
          anchors_ =
              cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
        }
        else
        {
          return absl::UnavailableError("No anchor data available.");
        }
        anchors_init_ = true;
      }
      std::vector<float> boxes(num_boxes_ * num_coords_);
      MP_RETURN_IF_ERROR(DecodeBoxes(raw_boxes, anchors_, &boxes));

      std::vector<float> detection_scores(num_boxes_);
      std::vector<int> detection_classes(num_boxes_);

      // Filter classes by scores.
      for (int i = 0; i < num_boxes_; ++i)
      {
        // class_id stays -1 when every class is ignored.
        int class_id = -1;
        float max_score = -std::numeric_limits<float>::max();
        const float *box_scores = raw_scores + i * num_classes_;
        // Find the top score for box i.
        for (int score_idx = 0; score_idx < num_classes_; ++score_idx)
        {
          if (ignore_classes_.find(score_idx) != ignore_classes_.end())
          {
            continue;
          }
          float score = box_scores[score_idx];
          if (options_.sigmoid_score())
          {
            if (options_.has_score_clipping_thresh())
            {
              // Clip the logit to [-thresh, thresh] before the sigmoid.
              score = score < -options_.score_clipping_thresh()
                          ? -options_.score_clipping_thresh()
                          : score;
              score = score > options_.score_clipping_thresh()
                          ? options_.score_clipping_thresh()
                          : score;
            }
            score = 1.0f / (1.0f + std::exp(-score));
          }
          if (max_score < score)
          {
            max_score = score;
            class_id = score_idx;
          }
        }
        detection_scores[i] = max_score;
        detection_classes[i] = class_id;
      }

      MP_RETURN_IF_ERROR(
          ConvertToDetections(boxes.data(), detection_scores.data(),
                              detection_classes.data(), output_detections));
    }
    else
    {
      // Postprocessing on CPU with postprocessing op (e.g. anchor decoding and
      // non-maximum suppression) within the model.
      RET_CHECK_EQ(input_tensors.size(), 4);

      const TfLiteTensor *detection_boxes_tensor = &input_tensors[0];
      const TfLiteTensor *detection_classes_tensor = &input_tensors[1];
      const TfLiteTensor *detection_scores_tensor = &input_tensors[2];
      const TfLiteTensor *num_boxes_tensor = &input_tensors[3];
      RET_CHECK_EQ(num_boxes_tensor->dims->size, 1);
      RET_CHECK_EQ(num_boxes_tensor->dims->data[0], 1);
      RET_CHECK_EQ(detection_boxes_tensor->dims->size, 3);
      RET_CHECK_EQ(detection_boxes_tensor->dims->data[0], 1);
      const int max_detections = detection_boxes_tensor->dims->data[1];
      RET_CHECK_EQ(detection_boxes_tensor->dims->data[2], num_coords_);
      RET_CHECK_EQ(detection_classes_tensor->dims->size, 2);
      RET_CHECK_EQ(detection_classes_tensor->dims->data[0], 1);
      RET_CHECK_EQ(detection_classes_tensor->dims->data[1], max_detections);
      RET_CHECK_EQ(detection_scores_tensor->dims->size, 2);
      RET_CHECK_EQ(detection_scores_tensor->dims->data[0], 1);
      RET_CHECK_EQ(detection_scores_tensor->dims->data[1], max_detections);

      // The box count comes from the model as a float. It must stay within the
      // capacity of the other output tensors, otherwise the loops below (and in
      // ConvertToDetections) would read past the end of their buffers.
      const int reported_boxes = static_cast<int>(num_boxes_tensor->data.f[0]);
      RET_CHECK_GE(reported_boxes, 0);
      RET_CHECK_LE(reported_boxes, max_detections);
      num_boxes_ = reported_boxes;

      const float *detection_boxes = detection_boxes_tensor->data.f;
      const float *detection_scores = detection_scores_tensor->data.f;
      // Class ids are delivered as floats; convert them to ints.
      std::vector<int> detection_classes(num_boxes_);
      for (int i = 0; i < num_boxes_; ++i)
      {
        detection_classes[i] =
            static_cast<int>(detection_classes_tensor->data.f[i]);
      }
      MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores,
                                             detection_classes.data(),
                                             output_detections));
    }
    return absl::OkStatus();
  }
  // Converts GPU tensors from the TENSORS_GPU stream into detections.
  //
  // Both GPU backends follow the same steps: copy the raw box and score tensors
  // into the calculator's own buffers, load the anchors once (from the ANCHORS
  // side packet or from a third input tensor), run the decode and score
  // programs, copy the results back to the CPU and hand them to
  // ConvertToDetections.
  //
  // When neither GPU backend is compiled in, an error is logged and the
  // function still returns OK without producing detections.
  absl::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
      CalculatorContext *cc, std::vector<Detection> *output_detections)
  {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
    const auto &input_tensors =
        cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
    RET_CHECK_GE(input_tensors.size(), 2);

    // All GL work below is run through the helper's RunInGlContext.
    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, &input_tensors, &cc,
                                                   &output_detections]()
                                                      -> absl::Status
                                                  {
    // Copy inputs.
    MP_RETURN_IF_ERROR(
        CopyBuffer(input_tensors[0], gpu_data_->raw_boxes_buffer));
    MP_RETURN_IF_ERROR(
        CopyBuffer(input_tensors[1], gpu_data_->raw_scores_buffer));
    // Anchors are uploaded only once, on the first packet.
    if (!anchors_init_) {
      if (side_packet_anchors_) {
        ABSL_CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
        const auto& anchors =
            cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
        std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox);
        ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data());
        MP_RETURN_IF_ERROR(gpu_data_->raw_anchors_buffer.Write<float>(
            absl::MakeSpan(raw_anchors)));
      } else {
        // NOTE(review): this branch uses ABSL_CHECK_EQ while the Metal branch
        // below uses RET_CHECK_EQ for the same condition.
        ABSL_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
        MP_RETURN_IF_ERROR(
            CopyBuffer(input_tensors[2], gpu_data_->raw_anchors_buffer));
      }
      anchors_init_ = true;
    }

    // Run shaders.
    // Decode boxes.
    MP_RETURN_IF_ERROR(gpu_data_->decoded_boxes_buffer.BindToIndex(0));
    MP_RETURN_IF_ERROR(gpu_data_->raw_boxes_buffer.BindToIndex(1));
    MP_RETURN_IF_ERROR(gpu_data_->raw_anchors_buffer.BindToIndex(2));
    // One workgroup per box.
    const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1};
    MP_RETURN_IF_ERROR(gpu_data_->decode_program.Dispatch(decode_workgroups));

    // Score boxes.
    MP_RETURN_IF_ERROR(gpu_data_->scored_boxes_buffer.BindToIndex(0));
    MP_RETURN_IF_ERROR(gpu_data_->raw_scores_buffer.BindToIndex(1));
    const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1};
    MP_RETURN_IF_ERROR(gpu_data_->score_program.Dispatch(score_workgroups));

    // Copy decoded boxes from GPU to CPU.
    std::vector<float> boxes(num_boxes_ * num_coords_);
    MP_RETURN_IF_ERROR(
        gpu_data_->decoded_boxes_buffer.Read(absl::MakeSpan(boxes)));
    // Two floats per box: the score followed by the class id.
    std::vector<float> score_class_id_pairs(num_boxes_ * 2);
    MP_RETURN_IF_ERROR(gpu_data_->scored_boxes_buffer.Read(
        absl::MakeSpan(score_class_id_pairs)));

    // TODO: b/138851969. Is it possible to output a float vector
    // for score and an int vector for class so that we can avoid copying twice?
    std::vector<float> detection_scores(num_boxes_);
    std::vector<int> detection_classes(num_boxes_);
    for (int i = 0; i < num_boxes_; ++i) {
      detection_scores[i] = score_class_id_pairs[i * 2];
      detection_classes[i] = static_cast<int>(score_class_id_pairs[i * 2 + 1]);
    }
    MP_RETURN_IF_ERROR(
        ConvertToDetections(boxes.data(), detection_scores.data(),
                            detection_classes.data(), output_detections));

    return absl::OkStatus(); }));
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE

    const auto &input_tensors =
        cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
    RET_CHECK_GE(input_tensors.size(), 2);

    // Copy inputs.
    [MPPMetalUtil blitMetalBufferTo:gpu_data_->raw_boxes_buffer
                               from:input_tensors[0]
                           blocking:false
                      commandBuffer:[gpu_helper_ commandBuffer]];
    [MPPMetalUtil blitMetalBufferTo:gpu_data_->raw_scores_buffer
                               from:input_tensors[1]
                           blocking:false
                      commandBuffer:[gpu_helper_ commandBuffer]];
    // Anchors are uploaded only once, on the first packet.
    if (!anchors_init_)
    {
      if (side_packet_anchors_)
      {
        ABSL_CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
        const auto &anchors =
            cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
        std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox);
        ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data());
        // The flattened anchors are written straight into the buffer contents.
        memcpy([gpu_data_->raw_anchors_buffer contents], raw_anchors.data(),
               raw_anchors.size() * sizeof(float));
      }
      else
      {
        RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
        [MPPMetalUtil blitMetalBufferTo:gpu_data_->raw_anchors_buffer
                                   from:input_tensors[2]
                               blocking:false
                          commandBuffer:[gpu_helper_ commandBuffer]];
      }
      anchors_init_ = true;
    }

    // Run shaders.
    // Both programs are encoded into a single command buffer.
    id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
    command_buffer.label = @"TfLiteDecodeAndScoreBoxes";
    id<MTLComputeCommandEncoder> command_encoder =
        [command_buffer computeCommandEncoder];
    // Decode boxes: one single-thread threadgroup per box.
    [command_encoder setComputePipelineState:gpu_data_->decode_program];
    [command_encoder setBuffer:gpu_data_->decoded_boxes_buffer
                        offset:0
                       atIndex:0];
    [command_encoder setBuffer:gpu_data_->raw_boxes_buffer offset:0 atIndex:1];
    [command_encoder setBuffer:gpu_data_->raw_anchors_buffer offset:0 atIndex:2];
    MTLSize decode_threads_per_group = MTLSizeMake(1, 1, 1);
    MTLSize decode_threadgroups = MTLSizeMake(num_boxes_, 1, 1);
    [command_encoder dispatchThreadgroups:decode_threadgroups
                    threadsPerThreadgroup:decode_threads_per_group];

    // Score boxes: one threadgroup per box with num_classes_ threads each.
    [command_encoder setComputePipelineState:gpu_data_->score_program];
    [command_encoder setBuffer:gpu_data_->scored_boxes_buffer offset:0 atIndex:0];
    [command_encoder setBuffer:gpu_data_->raw_scores_buffer offset:0 atIndex:1];
    MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1);
    MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1);
    [command_encoder dispatchThreadgroups:score_threadgroups
                    threadsPerThreadgroup:score_threads_per_group];
    [command_encoder endEncoding];
    [MPPMetalUtil commitCommandBufferAndWait:command_buffer];

    // Copy decoded boxes from GPU to CPU.
    std::vector<float> boxes(num_boxes_ * num_coords_);
    memcpy(boxes.data(), [gpu_data_->decoded_boxes_buffer contents],
           num_boxes_ * num_coords_ * sizeof(float));
    // Two floats per box: the score followed by the class id.
    std::vector<float> score_class_id_pairs(num_boxes_ * 2);
    memcpy(score_class_id_pairs.data(), [gpu_data_->scored_boxes_buffer contents],
           num_boxes_ * 2 * sizeof(float));

    // Output detections.
    // TODO Adjust shader to avoid copying shader output twice.
    std::vector<float> detection_scores(num_boxes_);
    std::vector<int> detection_classes(num_boxes_);
    for (int i = 0; i < num_boxes_; ++i)
    {
      detection_scores[i] = score_class_id_pairs[i * 2];
      detection_classes[i] = static_cast<int>(score_class_id_pairs[i * 2 + 1]);
    }
    MP_RETURN_IF_ERROR(ConvertToDetections(boxes.data(), detection_scores.data(),
                                           detection_classes.data(),
                                           output_detections));

#else
    // NOTE(review): only logs; the function still returns OK below, so the
    // caller emits an empty detection vector for this packet.
    ABSL_LOG(ERROR) << "GPU input on non-Android not supported yet.";
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
    return absl::OkStatus();
  }

  // Drops the GPU programs and buffers owned by this calculator. With the GL
  // backend the release is routed through the helper's RunInGlContext.
  absl::Status TfLiteTensorsToDetectionsCalculator::Close(CalculatorContext *cc)
  {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
    gpu_helper_.RunInGlContext([this]()
                               { gpu_data_ = nullptr; });
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
    gpu_data_ = nullptr;
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE

    return absl::OkStatus();
  }

  // Reads the calculator options from the graph config and caches the values
  // used on every packet (class/box/coordinate counts and the set of ignored
  // classes).
  //
  // Returns:
  //   OK on success; an error status when the options are inconsistent, i.e.
  //   num_values_per_keypoint is not 2, or num_coords does not equal
  //   num_keypoints * num_values_per_keypoint + 4. Invalid options are reported
  //   through the returned status rather than by aborting the process, so a
  //   misconfigured graph fails to start with a readable error.
  absl::Status TfLiteTensorsToDetectionsCalculator::LoadOptions(
      CalculatorContext *cc)
  {
    // Get calculator options specified in the graph.
    options_ =
        cc->Options<::mediapipe::TfLiteTensorsToDetectionsCalculatorOptions>();

    num_classes_ = options_.num_classes();
    num_boxes_ = options_.num_boxes();
    num_coords_ = options_.num_coords();

    // Currently only support 2D when num_values_per_keypoint equals to 2.
    RET_CHECK_EQ(options_.num_values_per_keypoint(), 2)
        << "Only 2 values per keypoint are supported.";

    // Check if the output size is equal to the requested boxes and keypoints.
    RET_CHECK_EQ(options_.num_keypoints() * options_.num_values_per_keypoint() +
                     kNumCoordsPerBox,
                 num_coords_)
        << "num_coords must equal num_keypoints * num_values_per_keypoint + "
        << kNumCoordsPerBox << ".";

    for (const int ignored_class : options_.ignore_classes())
    {
      ignore_classes_.insert(ignored_class);
    }

    return absl::OkStatus();
  }

  // Decodes raw model outputs into boxes (and keypoints) using the anchors.
  //
  // For every box i, the raw center is scaled by the anchor size and shifted by
  // the anchor center; the raw size is scaled by the anchor size, optionally
  // after an exponential. The result is written as [ymin, xmin, ymax, xmax]
  // into the first four slots of box i in `boxes`; decoded keypoints are
  // written as (x, y) pairs at the same offsets they were read from.
  //
  // Args:
  //   raw_boxes: num_boxes_ * num_coords_ raw values from the model.
  //   anchors: at least num_boxes_ anchors.
  //   boxes: output, must already hold at least num_boxes_ * num_coords_ values.
  // Returns:
  //   OK on success; an error status when `anchors` or `boxes` is too small for
  //   num_boxes_ boxes (instead of reading or writing out of bounds).
  absl::Status TfLiteTensorsToDetectionsCalculator::DecodeBoxes(
      const float *raw_boxes, const std::vector<Anchor> &anchors,
      std::vector<float> *boxes)
  {
    RET_CHECK(boxes != nullptr);
    if (num_boxes_ <= 0)
    {
      // Nothing to decode.
      return absl::OkStatus();
    }
    // Anchors coming from a side packet are not validated elsewhere on the CPU
    // path, so make sure every index used below is in range.
    const size_t box_count = static_cast<size_t>(num_boxes_);
    RET_CHECK_GE(anchors.size(), box_count)
        << "Not enough anchors for the configured number of boxes.";
    RET_CHECK_GE(boxes->size(), box_count * static_cast<size_t>(num_coords_))
        << "Output buffer is too small for the decoded boxes.";

    for (int i = 0; i < num_boxes_; ++i)
    {
      const Anchor &anchor = anchors[i];
      const int box_offset = i * num_coords_ + options_.box_coord_offset();

      float y_center = raw_boxes[box_offset];
      float x_center = raw_boxes[box_offset + 1];
      float h = raw_boxes[box_offset + 2];
      float w = raw_boxes[box_offset + 3];
      if (options_.reverse_output_order())
      {
        // The model emits (x, y, w, h) instead of (y, x, h, w).
        x_center = raw_boxes[box_offset];
        y_center = raw_boxes[box_offset + 1];
        w = raw_boxes[box_offset + 2];
        h = raw_boxes[box_offset + 3];
      }

      x_center = x_center / options_.x_scale() * anchor.w() + anchor.x_center();
      y_center = y_center / options_.y_scale() * anchor.h() + anchor.y_center();

      if (options_.apply_exponential_on_box_size())
      {
        h = std::exp(h / options_.h_scale()) * anchor.h();
        w = std::exp(w / options_.w_scale()) * anchor.w();
      }
      else
      {
        h = h / options_.h_scale() * anchor.h();
        w = w / options_.w_scale() * anchor.w();
      }

      const float ymin = y_center - h / 2.f;
      const float xmin = x_center - w / 2.f;
      const float ymax = y_center + h / 2.f;
      const float xmax = x_center + w / 2.f;

      // The decoded box always lands in the first four slots of box i,
      // regardless of box_coord_offset.
      (*boxes)[i * num_coords_ + 0] = ymin;
      (*boxes)[i * num_coords_ + 1] = xmin;
      (*boxes)[i * num_coords_ + 2] = ymax;
      (*boxes)[i * num_coords_ + 3] = xmax;

      // The loop does not run when there are no keypoints.
      for (int k = 0; k < options_.num_keypoints(); ++k)
      {
        const int offset = i * num_coords_ + options_.keypoint_coord_offset() +
                           k * options_.num_values_per_keypoint();

        float keypoint_y = raw_boxes[offset];
        float keypoint_x = raw_boxes[offset + 1];
        if (options_.reverse_output_order())
        {
          keypoint_x = raw_boxes[offset];
          keypoint_y = raw_boxes[offset + 1];
        }

        // Keypoints are stored as (x, y).
        (*boxes)[offset] =
            keypoint_x / options_.x_scale() * anchor.w() + anchor.x_center();
        (*boxes)[offset + 1] =
            keypoint_y / options_.y_scale() * anchor.h() + anchor.y_center();
      }
    }

    return absl::OkStatus();
  }

  // Turns the first num_boxes_ decoded boxes into Detection protos and appends
  // them to `output_detections`.
  //
  // A box is skipped when a minimum score threshold is configured and its score
  // is below it, or when its width or height is negative. Keypoints, if any,
  // are attached as relative keypoints, with y mirrored when flip_vertically
  // is set.
  absl::Status TfLiteTensorsToDetectionsCalculator::ConvertToDetections(
      const float *detection_boxes, const float *detection_scores,
      const int *detection_classes, std::vector<Detection> *output_detections)
  {
    for (int box = 0; box < num_boxes_; ++box)
    {
      const float score = detection_scores[box];
      const bool below_threshold = options_.has_min_score_thresh() &&
                                   score < options_.min_score_thresh();
      if (below_threshold)
      {
        continue;
      }

      const int base = box * num_coords_;
      const float *coords = detection_boxes + base;
      Detection detection =
          ConvertToDetection(coords[0], coords[1], coords[2], coords[3], score,
                             detection_classes[box], options_.flip_vertically());

      // Decoded detection boxes could have negative values for width/height due
      // to model prediction. Filter out those boxes since some downstream
      // calculators may assume non-negative values. (b/171391719)
      const auto &bbox = detection.location_data().relative_bounding_box();
      if (bbox.width() < 0 || bbox.height() < 0)
      {
        continue;
      }

      // Add keypoints.
      if (options_.num_keypoints() > 0)
      {
        auto *location_data = detection.mutable_location_data();
        const int stride = options_.num_values_per_keypoint();
        const int value_count = options_.num_keypoints() * stride;
        for (int value = 0; value < value_count; value += stride)
        {
          const int keypoint_index =
              base + options_.keypoint_coord_offset() + value;
          const float kp_x = detection_boxes[keypoint_index + 0];
          const float kp_y = detection_boxes[keypoint_index + 1];
          auto keypoint = location_data->add_relative_keypoints();
          keypoint->set_x(kp_x);
          if (options_.flip_vertically())
          {
            keypoint->set_y(1.f - kp_y);
          }
          else
          {
            keypoint->set_y(kp_y);
          }
        }
      }
      output_detections->push_back(detection);
    }
    return absl::OkStatus();
  }

  // Builds a Detection holding one score, one label id and a relative bounding
  // box with origin (box_xmin, box_ymin) and size
  // (box_xmax - box_xmin, box_ymax - box_ymin). When `flip_vertically` is set,
  // the box origin becomes 1 - box_ymax.
  Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
      float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score,
      int class_id, bool flip_vertically)
  {
    Detection result;
    result.add_score(score);
    result.add_label_id(class_id);

    auto *location = result.mutable_location_data();
    location->set_format(LocationData::RELATIVE_BOUNDING_BOX);

    float top = box_ymin;
    if (flip_vertically)
    {
      top = 1.f - box_ymax;
    }

    auto *bbox = location->mutable_relative_bounding_box();
    bbox->set_xmin(box_xmin);
    bbox->set_ymin(top);
    bbox->set_width(box_xmax - box_xmin);
    bbox->set_height(box_ymax - box_ymin);
    return result;
  }

  // Creates the GPU programs and buffers used to decode and score detections.
  //
  // Depending on the build configuration this uses either OpenGL ES compute
  // shaders (MEDIAPIPE_TFLITE_GL_INFERENCE) or Metal kernels
  // (MEDIAPIPE_TFLITE_METAL_INFERENCE). In both cases two programs are built,
  // with the calculator options substituted into the shader source as
  // constants:
  //   * decode: reads raw boxes and anchors, writes (ymin, xmin, ymax, xmax)
  //     per box plus the decoded keypoints.
  //   * score:  reads the raw per-class scores, writes (max score, class index)
  //     per box.
  // The input/output buffers for both programs are allocated here and stored
  // in `gpu_data_`.
  //
  // `cc` is not used by the visible code.
  //
  // Returns a non-OK status if a shader fails to compile, a buffer cannot be
  // created, or the configuration exceeds what the score kernel supports.
  // These limits used to be enforced with ABSL_CHECK (process abort); they are
  // reported through the returned status instead so a bad graph configuration
  // does not take down the host application. When neither GPU backend is
  // compiled in, this is a no-op that returns OkStatus.
  absl::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
      CalculatorContext *cc)
  {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
    // All GL object creation has to happen on the helper's GL context.
    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status {
      gpu_data_ = absl::make_unique<GPUData>();

      // A shader to decode detection boxes.
      const std::string decode_src = absl::Substitute(
          R"( #version 310 es

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

layout(location = 0) uniform vec4 scale;

layout(std430, binding = 0) writeonly buffer Output {
  float data[];
} boxes;

layout(std430, binding = 1) readonly buffer Input0 {
  float data[];
} raw_boxes;

layout(std430, binding = 2) readonly buffer Input1 {
  float data[];
} raw_anchors;

uint num_coords = uint($0);
int reverse_output_order = int($1);
int apply_exponential = int($2);
int box_coord_offset = int($3);
int num_keypoints = int($4);
int keypt_coord_offset = int($5);
int num_values_per_keypt = int($6);

void main() {
  uint g_idx = gl_GlobalInvocationID.x;  // box index
  uint box_offset = g_idx * num_coords + uint(box_coord_offset);
  uint anchor_offset = g_idx * uint(4);  // check kNumCoordsPerBox

  float y_center, x_center, h, w;

  if (reverse_output_order == int(0)) {
    y_center = raw_boxes.data[box_offset + uint(0)];
    x_center = raw_boxes.data[box_offset + uint(1)];
    h = raw_boxes.data[box_offset + uint(2)];
    w = raw_boxes.data[box_offset + uint(3)];
  } else {
    x_center = raw_boxes.data[box_offset + uint(0)];
    y_center = raw_boxes.data[box_offset + uint(1)];
    w = raw_boxes.data[box_offset + uint(2)];
    h = raw_boxes.data[box_offset + uint(3)];
  }

  float anchor_yc = raw_anchors.data[anchor_offset + uint(0)];
  float anchor_xc = raw_anchors.data[anchor_offset + uint(1)];
  float anchor_h  = raw_anchors.data[anchor_offset + uint(2)];
  float anchor_w  = raw_anchors.data[anchor_offset + uint(3)];

  x_center = x_center / scale.x * anchor_w + anchor_xc;
  y_center = y_center / scale.y * anchor_h + anchor_yc;

  if (apply_exponential == int(1)) {
    h = exp(h / scale.w) * anchor_h;
    w = exp(w / scale.z) * anchor_w;
  } else {
    h = (h / scale.w) * anchor_h;
    w = (w / scale.z) * anchor_w;
  }

  float ymin = y_center - h / 2.0;
  float xmin = x_center - w / 2.0;
  float ymax = y_center + h / 2.0;
  float xmax = x_center + w / 2.0;

  boxes.data[box_offset + uint(0)] = ymin;
  boxes.data[box_offset + uint(1)] = xmin;
  boxes.data[box_offset + uint(2)] = ymax;
  boxes.data[box_offset + uint(3)] = xmax;

  if (num_keypoints > int(0)){
    for (int k = 0; k < num_keypoints; ++k) {
      int kp_offset =
        int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
      float kp_y, kp_x;
      if (reverse_output_order == int(0)) {
        kp_y = raw_boxes.data[kp_offset + int(0)];
        kp_x = raw_boxes.data[kp_offset + int(1)];
      } else {
        kp_x = raw_boxes.data[kp_offset + int(0)];
        kp_y = raw_boxes.data[kp_offset + int(1)];
      }
      boxes.data[kp_offset + int(0)] = kp_x / scale.x * anchor_w + anchor_xc;
      boxes.data[kp_offset + int(1)] = kp_y / scale.y * anchor_h + anchor_yc;
    }
  }
})",
          options_.num_coords(),  // box xywh
          options_.reverse_output_order() ? 1 : 0,
          options_.apply_exponential_on_box_size() ? 1 : 0,
          options_.box_coord_offset(), options_.num_keypoints(),
          options_.keypoint_coord_offset(), options_.num_values_per_keypoint());

      // Shader program
      GlShader decode_shader;
      MP_RETURN_IF_ERROR(
          GlShader::CompileShader(GL_COMPUTE_SHADER, decode_src, &decode_shader));
      MP_RETURN_IF_ERROR(GpuProgram::CreateWithShader(
          decode_shader, &gpu_data_->decode_program));
      // Outputs
      size_t decoded_boxes_length = num_boxes_ * num_coords_;
      MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
          decoded_boxes_length, &gpu_data_->decoded_boxes_buffer));
      // Inputs
      size_t raw_boxes_length = num_boxes_ * num_coords_;
      MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
          raw_boxes_length, &gpu_data_->raw_boxes_buffer));
      size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox;
      MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
          raw_anchors_length, &gpu_data_->raw_anchors_buffer));
      // Parameters. The uniform is packed as (x, y, w, h), so inside the shader
      // scale.z is the width scale and scale.w is the height scale.
      glUseProgram(gpu_data_->decode_program.id());
      glUniform4f(0, options_.x_scale(), options_.y_scale(), options_.w_scale(),
                  options_.h_scale());

      // A shader to score detection boxes. One work group handles one box, with
      // one invocation along y per class; invocation 0 then reduces the
      // per-class scores held in shared memory to the best (score, class).
      const std::string score_src = absl::Substitute(
          R"( #version 310 es

layout(local_size_x = 1, local_size_y = $0, local_size_z = 1) in;

#define FLT_MAX 1.0e+37

shared float local_scores[$0];

layout(std430, binding = 0) writeonly buffer Output {
  float data[];
} scored_boxes;

layout(std430, binding = 1) readonly buffer Input0 {
  float data[];
} raw_scores;

uint num_classes = uint($0);
int apply_sigmoid = int($1);
int apply_clipping_thresh = int($2);
float clipping_thresh = float($3);
int ignore_class_0 = int($4);

float optional_sigmoid(float x) {
  if (apply_sigmoid == int(0)) return x;
  if (apply_clipping_thresh == int(1)) {
    x = clamp(x, -clipping_thresh, clipping_thresh);
  }
  x = 1.0 / (1.0 + exp(-x));
  return x;
}

void main() {
  uint g_idx = gl_GlobalInvocationID.x;   // box idx
  uint s_idx =  gl_LocalInvocationID.y;   // score/class idx

  // load all scores into shared memory
  float score = raw_scores.data[g_idx * num_classes + s_idx];
  local_scores[s_idx] = optional_sigmoid(score);
  memoryBarrierShared();
  barrier();

  // find max score in shared memory
  if (s_idx == uint(0)) {
    float max_score = -FLT_MAX;
    float max_class = -1.0;
    for (int i=ignore_class_0; i<int(num_classes); ++i) {
      if (local_scores[i] > max_score) {
        max_score = local_scores[i];
        max_class = float(i);
      }
    }
    scored_boxes.data[g_idx * uint(2) + uint(0)] = max_score;
    scored_boxes.data[g_idx * uint(2) + uint(1)] = max_class;
  }
})",
          num_classes_, options_.sigmoid_score() ? 1 : 0,
          options_.has_score_clipping_thresh() ? 1 : 0,
          options_.has_score_clipping_thresh() ? options_.score_clipping_thresh()
                                               : 0,
          !ignore_classes_.empty() ? 1 : 0);

      // # filter classes supported is hardware dependent: the shader uses one
      // invocation per class along y, so the class count is bounded by the
      // device's maximum work-group size in that dimension.
      int max_wg_size;  //  typically <= 1024
      glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
                      &max_wg_size);  // y-dim
      RET_CHECK(num_classes_ < max_wg_size)
          << "# classes must be < " << max_wg_size;
      // TODO support better filtering.
      // NOTE(review): this only bounds the number of ignored classes; the
      // shader always skips index 0 when any class is ignored. Confirm that
      // ignore_classes_ can only ever contain 0 on the GPU path.
      RET_CHECK(ignore_classes_.size() <= 1)
          << "Only ignore class 0 is allowed";

      // Shader program
      GlShader score_shader;
      MP_RETURN_IF_ERROR(
          GlShader::CompileShader(GL_COMPUTE_SHADER, score_src, &score_shader));
      MP_RETURN_IF_ERROR(
          GpuProgram::CreateWithShader(score_shader, &gpu_data_->score_program));
      // Outputs
      size_t scored_boxes_length = num_boxes_ * 2;  // score, class
      MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
          scored_boxes_length, &gpu_data_->scored_boxes_buffer));
      // Inputs
      size_t raw_scores_length = num_boxes_ * num_classes_;
      MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
          raw_scores_length, &gpu_data_->raw_scores_buffer));

      return absl::OkStatus(); }));

#elif MEDIAPIPE_TFLITE_METAL_INFERENCE

    gpu_data_ = absl::make_unique<GPUData>();
    id<MTLDevice> device = gpu_helper_.mtlDevice;

    // A shader to decode detection boxes. The kernel source is assembled in
    // three pieces: the signature plus option constants, the scale vector
    // (packed as x, y, w, h), and the option-independent body.
    std::string decode_src = absl::Substitute(
        R"(
#include <metal_stdlib>

using namespace metal;

kernel void decodeKernel(
    device float*                   boxes       [[ buffer(0) ]],
    device float*                   raw_boxes   [[ buffer(1) ]],
    device float*                   raw_anchors [[ buffer(2) ]],
    uint2                           gid         [[ thread_position_in_grid ]]) {

  uint num_coords = uint($0);
  int reverse_output_order = int($1);
  int apply_exponential = int($2);
  int box_coord_offset = int($3);
  int num_keypoints = int($4);
  int keypt_coord_offset = int($5);
  int num_values_per_keypt = int($6);
)",
        options_.num_coords(), // box xywh
        options_.reverse_output_order() ? 1 : 0,
        options_.apply_exponential_on_box_size() ? 1 : 0,
        options_.box_coord_offset(), options_.num_keypoints(),
        options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
    decode_src += absl::Substitute(
        R"(
  float4 scale = float4(($0),($1),($2),($3));
)",
        options_.x_scale(), options_.y_scale(), options_.w_scale(),
        options_.h_scale());
    decode_src += R"(
  uint g_idx = gid.x;
  uint box_offset = g_idx * num_coords + uint(box_coord_offset);
  uint anchor_offset = g_idx * uint(4);  // check kNumCoordsPerBox

  float y_center, x_center, h, w;

  if (reverse_output_order == int(0)) {
    y_center = raw_boxes[box_offset + uint(0)];
    x_center = raw_boxes[box_offset + uint(1)];
    h = raw_boxes[box_offset + uint(2)];
    w = raw_boxes[box_offset + uint(3)];
  } else {
    x_center = raw_boxes[box_offset + uint(0)];
    y_center = raw_boxes[box_offset + uint(1)];
    w = raw_boxes[box_offset + uint(2)];
    h = raw_boxes[box_offset + uint(3)];
  }

  float anchor_yc = raw_anchors[anchor_offset + uint(0)];
  float anchor_xc = raw_anchors[anchor_offset + uint(1)];
  float anchor_h  = raw_anchors[anchor_offset + uint(2)];
  float anchor_w  = raw_anchors[anchor_offset + uint(3)];

  x_center = x_center / scale.x * anchor_w + anchor_xc;
  y_center = y_center / scale.y * anchor_h + anchor_yc;

  if (apply_exponential == int(1)) {
    h = exp(h / scale.w) * anchor_h;
    w = exp(w / scale.z) * anchor_w;
  } else {
    h = (h / scale.w) * anchor_h;
    w = (w / scale.z) * anchor_w;
  }

  float ymin = y_center - h / 2.0;
  float xmin = x_center - w / 2.0;
  float ymax = y_center + h / 2.0;
  float xmax = x_center + w / 2.0;

  boxes[box_offset + uint(0)] = ymin;
  boxes[box_offset + uint(1)] = xmin;
  boxes[box_offset + uint(2)] = ymax;
  boxes[box_offset + uint(3)] = xmax;

  if (num_keypoints > int(0)){
    for (int k = 0; k < num_keypoints; ++k) {
      int kp_offset =
        int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
      float kp_y, kp_x;
      if (reverse_output_order == int(0)) {
        kp_y = raw_boxes[kp_offset + int(0)];
        kp_x = raw_boxes[kp_offset + int(1)];
      } else {
        kp_x = raw_boxes[kp_offset + int(0)];
        kp_y = raw_boxes[kp_offset + int(1)];
      }
      boxes[kp_offset + int(0)] = kp_x / scale.x * anchor_w + anchor_xc;
      boxes[kp_offset + int(1)] = kp_y / scale.y * anchor_h + anchor_yc;
    }
  }
})";

    {
      // Shader program
      NSString *library_source =
          [NSString stringWithUTF8String:decode_src.c_str()];
      NSError *error = nil;
      id<MTLLibrary> library = [device newLibraryWithSource:library_source
                                                    options:nil
                                                      error:&error];
      RET_CHECK(library != nil) << "Couldn't create shader library "
                                << [[error localizedDescription] UTF8String];
      id<MTLFunction> kernel_func = nil;
      kernel_func = [library newFunctionWithName:@"decodeKernel"];
      RET_CHECK(kernel_func != nil) << "Couldn't create kernel function.";
      gpu_data_->decode_program =
          [device newComputePipelineStateWithFunction:kernel_func
                                                error:&error];
      RET_CHECK(gpu_data_->decode_program != nil)
          << "Couldn't create pipeline state "
          << [[error localizedDescription] UTF8String];
      // Outputs (Metal buffer lengths are in bytes, hence sizeof(float)).
      size_t decoded_boxes_length = num_boxes_ * num_coords_ * sizeof(float);
      gpu_data_->decoded_boxes_buffer =
          [device newBufferWithLength:decoded_boxes_length
                              options:MTLResourceStorageModeShared];
      // Inputs
      size_t raw_boxes_length = num_boxes_ * num_coords_ * sizeof(float);
      gpu_data_->raw_boxes_buffer =
          [device newBufferWithLength:raw_boxes_length
                              options:MTLResourceStorageModeShared];
      size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox * sizeof(float);
      gpu_data_->raw_anchors_buffer =
          [device newBufferWithLength:raw_anchors_length
                              options:MTLResourceStorageModeShared];
    }

    // A shader to score detection boxes. One thread per class along y writes
    // its score into threadgroup memory; thread 0 then picks the best
    // (score, class) for the box.
    const std::string score_src = absl::Substitute(
        R"(
#include <metal_stdlib>

using namespace metal;

float optional_sigmoid(float x) {
  int apply_sigmoid = int($1);
  int apply_clipping_thresh = int($2);
  float clipping_thresh = float($3);
  if (apply_sigmoid == int(0)) return x;
  if (apply_clipping_thresh == int(1)) {
    x = clamp(x, -clipping_thresh, clipping_thresh);
  }
  x = 1.0 / (1.0 + exp(-x));
  return x;
}

kernel void scoreKernel(
    device float*             scored_boxes [[ buffer(0) ]],
    device float*             raw_scores   [[ buffer(1) ]],
    uint2                     tid          [[ thread_position_in_threadgroup ]],
    uint2                     gid          [[ thread_position_in_grid ]]) {

  uint num_classes = uint($0);
  int apply_sigmoid = int($1);
  int apply_clipping_thresh = int($2);
  float clipping_thresh = float($3);
  int ignore_class_0 = int($4);

  uint g_idx = gid.x;   // box idx
  uint s_idx = tid.y;   // score/class idx

  // load all scores into shared memory
  threadgroup float local_scores[$0];
  float score = raw_scores[g_idx * num_classes + s_idx];
  local_scores[s_idx] = optional_sigmoid(score);
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // find max score in shared memory
  if (s_idx == uint(0)) {
    float max_score = -FLT_MAX;
    float max_class = -1.0;
    for (int i=ignore_class_0; i<int(num_classes); ++i) {
      if (local_scores[i] > max_score) {
        max_score = local_scores[i];
        max_class = float(i);
      }
    }
    scored_boxes[g_idx * uint(2) + uint(0)] = max_score;
    scored_boxes[g_idx * uint(2) + uint(1)] = max_class;
  }
})",
        num_classes_, options_.sigmoid_score() ? 1 : 0,
        options_.has_score_clipping_thresh() ? 1 : 0,
        options_.has_score_clipping_thresh() ? options_.score_clipping_thresh()
                                             : 0,
        ignore_classes_.size() ? 1 : 0);

    // TODO support better filtering.
    // NOTE(review): this only bounds the number of ignored classes; the kernel
    // always skips index 0 when any class is ignored. Confirm that
    // ignore_classes_ can only ever contain 0 on the GPU path.
    RET_CHECK(ignore_classes_.size() <= 1) << "Only ignore class 0 is allowed";

    {
      // Shader program
      NSString *library_source =
          [NSString stringWithUTF8String:score_src.c_str()];
      NSError *error = nil;
      id<MTLLibrary> library = [device newLibraryWithSource:library_source
                                                    options:nil
                                                      error:&error];
      RET_CHECK(library != nil) << "Couldn't create shader library "
                                << [[error localizedDescription] UTF8String];
      id<MTLFunction> kernel_func = nil;
      kernel_func = [library newFunctionWithName:@"scoreKernel"];
      RET_CHECK(kernel_func != nil) << "Couldn't create kernel function.";
      gpu_data_->score_program =
          [device newComputePipelineStateWithFunction:kernel_func
                                                error:&error];
      RET_CHECK(gpu_data_->score_program != nil)
          << "Couldn't create pipeline state "
          << [[error localizedDescription] UTF8String];
      // Outputs
      size_t scored_boxes_length = num_boxes_ * 2 * sizeof(float); // score,class
      gpu_data_->scored_boxes_buffer =
          [device newBufferWithLength:scored_boxes_length
                              options:MTLResourceStorageModeShared];
      // Inputs
      size_t raw_scores_length = num_boxes_ * num_classes_ * sizeof(float);
      gpu_data_->raw_scores_buffer =
          [device newBufferWithLength:raw_scores_length
                              options:MTLResourceStorageModeShared];
      // # filter classes supported is hardware dependent: the kernel uses one
      // thread per class, so the class count is bounded by the pipeline's
      // maximum threads per threadgroup.
      int max_wg_size = gpu_data_->score_program.maxTotalThreadsPerThreadgroup;
      RET_CHECK(num_classes_ < max_wg_size)
          << "# classes must be < " << max_wg_size;
    }

#endif // MEDIAPIPE_TFLITE_GL_INFERENCE / MEDIAPIPE_TFLITE_METAL_INFERENCE

    return absl::OkStatus();
  }

} // namespace mediapipe
